/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2018-2019. All rights reserved.
 * Description: model executor
 *
 */
#include "general_model_executor.h"

#include <thread>

#include "executor/model_executor_factory.h"
#include "control_client.h"
#include "framework/infra/log/log.h"
#include "general_compute/task/task_scheduler.h"
#include "general_compute/task/task_scheduler_factory.h"
#include "general_compute/task/task_node.h"

using namespace std;
using namespace hiai;
using namespace ge;

namespace hiai {

/*
 * Probe a caller-supplied memory allocator with an allocate/free round trip.
 * A null allocator is treated as "nothing to check" and succeeds.
 * Returns ge::FAIL only when the allocator cannot serve the request.
 */
static Status MemoryAllocatorCallBack(std::shared_ptr<hiai::MemoryAllocator> memAllocator)
{
    if (!memAllocator) {
        return ge::SUCCESS;
    }

    // The control client can inject an expected size for testing; any negative
    // value means "not configured" and falls back to the 16M default.
    const int32_t configured = ControlClient::GetInstance().GetExpectValue(ClientKey::CLIENT_Executor_MemAllocate_Size);
    const uint64_t bytes = (configured < 0) ? 0x1000000 : static_cast<uint64_t>(configured); /* 0x1000000: 16M */

    void* block = memAllocator->Allocate(bytes, MemoryType::SYSCACHE, 0);
    if (block == nullptr) {
        return ge::FAIL;
    }

    memAllocator->Free(block, MemoryType::SYSCACHE);
    return ge::SUCCESS;
}

GeneralModelExecutor::GeneralModelExecutor(uint32_t modelId) : modelId_(modelId)
{
    // %u: modelId_ is uint32_t; the previous %d would misprint ids above INT32_MAX.
    FMK_LOGI("general executor [%u] construct", modelId_);
}

GeneralModelExecutor::~GeneralModelExecutor()
{
    // %u: modelId_ is uint32_t; the previous %d would misprint ids above INT32_MAX.
    FMK_LOGI("general executor [%u] destruct", modelId_);
    // Best-effort cleanup; a destructor must not propagate failures.
    (void)Finalize();
}

/*
 * Initialize the executor. Currently only validates a caller-supplied memory
 * allocator via an allocate/free probe; the compiled model is not consumed yet.
 */
HCS_API_EXPORT Status GeneralModelExecutor::Init(const ModelExecutionOptions& options, const ICompiledModel* model)
{
    (void)model;
    // Replaced leftover "+++HERE+++" debug trace that was logged at ERROR severity.
    FMK_LOGI("general executor [%u] init", modelId_);
    if (options.memAllocator != nullptr) {
        // Refuse to initialize with an allocator that cannot serve requests.
        if (MemoryAllocatorCallBack(options.memAllocator) != ge::SUCCESS) {
            FMK_LOGE("general executor [%u] memory allocator probe failed", modelId_);
            return ge::FAIL;
        }
    }
    return ge::SUCCESS;
}

/*
 * Dynamic-shape support is not implemented for the general executor;
 * the call is accepted as a no-op and always succeeds.
 */
HCS_API_EXPORT Status GeneralModelExecutor::Reshape(
    const std::vector<std::vector<int64_t>>& inputShape, std::vector<std::vector<int64_t>>& outputShape)
{
    (void)inputShape;
    (void)outputShape;
    return ge::SUCCESS;
}

/*
 * FFRT task creation is not implemented for the general executor;
 * the call is accepted as a no-op and always succeeds.
 */
HCS_API_EXPORT Status GeneralModelExecutor::CreateFFRTTask(const std::vector<ge::TensorBuffer>& inputs,
    std::vector<ge::TensorBuffer>& outputs, void* ffrtTask)
{
    (void)inputs;
    (void)outputs;
    (void)ffrtTask;
    return ge::SUCCESS;
}

/*
 * Synchronous execute: schedules an empty placeholder task node through the
 * NORMAL task scheduler. The tensor buffers are not consumed yet.
 */
HCS_API_EXPORT Status GeneralModelExecutor::Execute(
    const std::vector<ge::TensorBuffer>& input, std::vector<ge::TensorBuffer>& output)
{
    (void)input;
    (void)output;
    // Replaced leftover "+++HERE+++" debug trace that was logged at ERROR severity.
    FMK_LOGI("general executor [%u] execute", modelId_);
    auto taskNode = CreateTaskNode(
        [](TaskNodeContextBasePtr context, const std::vector<const void*>& inDeps = {},
        const std::vector<const void*>& outDeps = {}) {
            (void)context;
            return ge::SUCCESS;
        },
        "empty");
    // Check the node like the scheduler below; previously a null node was
    // handed to Schedule() unchecked.
    if (taskNode == nullptr) {
        FMK_LOGE("general executor [%u] create task node failed", modelId_);
        return ge::FAIL;
    }
    TaskSchedulerPtr taskScheduler = hiai::TaskSchedulerFactory::Instance()->Create(NORMAL);
    if (taskScheduler == nullptr) {
        return ge::FAIL;
    }
    taskScheduler->Schedule(nullptr, taskNode, taskNode);
    return ge::SUCCESS;
}

/*
 * Asynchronous execute: schedules an empty placeholder task node, then invokes
 * the completion callback immediately with SUCCESS. Tensor buffers are not
 * consumed yet.
 */
HCS_API_EXPORT Status GeneralModelExecutor::ExecuteAsync(const std::vector<ge::TensorBuffer>& input,
    std::vector<ge::TensorBuffer>& output, uint32_t taskId, std::shared_ptr<ExecutionCallback> callback)
{
    (void)input;
    (void)output;
    auto taskNode = CreateTaskNode(
        [](TaskNodeContextBasePtr context, const std::vector<const void*>& inDeps = {},
        const std::vector<const void*>& outDeps = {}) {
            (void)context;
            return ge::SUCCESS;
        },
        "empty");
    // Check the node like the scheduler below; previously a null node was
    // handed to Schedule() unchecked.
    if (taskNode == nullptr) {
        FMK_LOGE("general executor [%u] create task node failed", modelId_);
        return ge::FAIL;
    }
    TaskSchedulerPtr taskScheduler = hiai::TaskSchedulerFactory::Instance()->Create(NORMAL);
    if (taskScheduler == nullptr) {
        return ge::FAIL;
    }
    taskScheduler->Schedule(nullptr, taskNode);
    // Callback is optional; report completion synchronously for now.
    if (callback) {
        (*callback)(modelId_, taskId, ge::SUCCESS);
    }
    return ge::SUCCESS;
}

/*
 * Remove every queued pipeline context whose task id matches taskId.
 * Returns ge::SUCCESS when at least one entry was removed, ge::FAIL otherwise.
 * The whole drain-and-rebuild runs under the queue mutex.
 */
HCS_API_EXPORT Status GeneralModelExecutor::CancelTask(uint32_t taskId)
{
    std::lock_guard<std::mutex> guard(taskPipeLineContextQueueMtx_);

    bool removed = false;
    std::queue<TaskPipeLineContextPtr> retained;
    // Drain the queue, keeping only contexts that do not match taskId.
    for (; !taskPipeLineContextQueue_.empty(); taskPipeLineContextQueue_.pop()) {
        const auto& ctx = taskPipeLineContextQueue_.front();
        if (ctx->GetTaskId() == taskId) {
            removed = true;
        } else {
            retained.push(ctx);
        }
    }
    taskPipeLineContextQueue_ = std::move(retained);

    return removed ? ge::SUCCESS : ge::FAIL;
}

/*
 * No executor-owned resources to release yet; finalization always succeeds.
 * Also invoked (best-effort) from the destructor.
 */
HCS_API_EXPORT Status GeneralModelExecutor::Finalize()
{
    return ge::SUCCESS;
}

/*
 * Priority control is not implemented for the general executor;
 * the requested value is accepted and ignored.
 */
HCS_API_EXPORT Status GeneralModelExecutor::SetPriority(int32_t priority)
{
    (void)priority;
    return ge::SUCCESS;
}

/*
 * Return the id this executor was constructed with. Previously returned a
 * hard-coded 1, which made it impossible for callers to correlate executors
 * with the model id passed to the constructor.
 */
HCS_API_EXPORT uint32_t GeneralModelExecutor::GetModelID()
{
    return modelId_;
}

/*
 * No in-flight inference state is tracked here, so cancellation only leaves
 * a trace log.
 */
HCS_API_EXPORT void GeneralModelExecutor::Cancel()
{
    FMK_LOGI("call cancel inference end.");
}

/*
 * Weight loading from a directory is not implemented for the general
 * executor; the call is accepted as a no-op and always succeeds.
 */
HCS_API_EXPORT ge::Status GeneralModelExecutor::InitWeights(const std::string& weightDir)
{
    (void)weightDir;
    return ge::SUCCESS;
}

/*
 * Weight-buffer lookup is not implemented for the general executor;
 * the output buffer is left untouched and the call always succeeds.
 */
HCS_API_EXPORT ge::Status GeneralModelExecutor::GetWeightBuffer(const std::string& weightFileName, ge::TensorBuffer& weight)
{
    (void)weightFileName;
    (void)weight;
    return ge::SUCCESS;
}

/*
 * Weight flushing is not implemented for the general executor;
 * the call is accepted as a no-op and always succeeds.
 */
HCS_API_EXPORT ge::Status GeneralModelExecutor::FlushWeight(const std::string& weightName, size_t offset, size_t size)
{
    (void)weightName;
    (void)offset;
    (void)size;
    return ge::SUCCESS;
}

/*
 * Memory initialization from the compiled model is not implemented for the
 * general executor; the call is accepted as a no-op and always succeeds.
 */
Status GeneralModelExecutor::InitMemory(
    const std::shared_ptr<MemoryAllocator>& memAllocator, const GeneralCompiledModel* model)
{
    (void)memAllocator;
    (void)model;
    return ge::SUCCESS;
}

/*
 * Input size-info extraction is not implemented for the general executor;
 * the output vector is left untouched and the call always succeeds.
 */
Status GeneralModelExecutor::GetInputSizeInfos(const GeneralCompiledModel* model,
    std::vector<TensorSizeInfo>& inputSizeInfos)
{
    (void)model;
    (void)inputSizeInfos;
    return ge::SUCCESS;
}

/*
 * Output size-info extraction is not implemented for the general executor;
 * the output vector is left untouched and the call always succeeds.
 * The parameter previously carried the copy-pasted name "inputSizeInfos";
 * renamed to reflect that this is the output side (definition parameter names
 * are not part of the call interface in C++).
 */
Status GeneralModelExecutor::GetOutputSizeInfos(const GeneralCompiledModel* model,
    std::vector<TensorSizeInfo>& outputSizeInfos)
{
    (void)model;
    (void)outputSizeInfos;
    return ge::SUCCESS;
}

/*
 * Input validation is not implemented for the general executor;
 * every input set is accepted.
 */
Status GeneralModelExecutor::CheckInputs(const std::vector<ge::BaseBuffer>& inputs)
{
    (void)inputs;
    return ge::SUCCESS;
}

/*
 * Output validation is not implemented for the general executor;
 * every output set is accepted.
 */
Status GeneralModelExecutor::CheckOutputs(std::vector<ge::BaseBuffer>& outputs)
{
    (void)outputs;
    return ge::SUCCESS;
}

// Register GeneralModelExecutor as the creator for each supported model type
// so the model executor factory can instantiate it by type tag.
REGISTER_MODEL_EXECUTOR_CREATOR(IR_GRAPH_MODEL, GeneralModelExecutor);
REGISTER_MODEL_EXECUTOR_CREATOR(OM_STANDARD_MODEL, GeneralModelExecutor);
REGISTER_MODEL_EXECUTOR_CREATOR(IGS_MODEL, GeneralModelExecutor);
REGISTER_MODEL_EXECUTOR_CREATOR(STANDARD_IR_GRAPH_MODEL, GeneralModelExecutor);
REGISTER_MODEL_EXECUTOR_CREATOR(HCS_PARTITION_MODEL, GeneralModelExecutor);
} // namespace hiai
